In [24]:
import numpy as np
import pandas as pd
In [256]:
def transform(points_x):
    """Map 2-D input points into an 8-term nonlinear feature space.

    Each input row (x1, x2) becomes the row
    (1, x1, x2, x1^2, x2^2, x1*x2, |x1 - x2|, |x1 + x2|).

    Args:
        points_x: 2-D array; only columns 0 and 1 are read.

    Returns:
        A new float array with one row per input row and 8 columns.
    """
    x1 = points_x[:, 0]
    x2 = points_x[:, 1]

    # Feature columns in output order; the leading scalar 1 is the intercept
    # and is broadcast down the whole column when assigned.
    columns = (
        1,
        x1,
        x2,
        np.power(x1, 2),
        np.power(x2, 2),
        x1 * x2,
        np.abs(x1 - x2),
        np.abs(x1 + x2),
    )

    features = np.zeros([points_x.shape[0], len(columns)])
    for j, column in enumerate(columns):
        features[:, j] = column
    return features
def solve_linear_regression(pi, labels):
    """Return the least-squares weights for the design matrix ``pi``.

    Computes pinv(pi^T pi) pi^T labels, i.e. the normal-equation solution
    with a pseudo-inverse in place of a plain inverse.

    Args:
        pi: 2-D design matrix, one row per example.
        labels: target values, one per row of ``pi``.

    Returns:
        The weight vector, one entry per column of ``pi``.
    """
    gram = np.dot(pi.T, pi)
    pseudo_inverse = np.dot(np.linalg.pinv(gram), pi.T)
    return np.dot(pseudo_inverse, labels)
def get_linear_regression_error(pi, labels, gx_vector):
    """Return the fraction of examples misclassified by sign(pi . gx_vector).

    Args:
        pi: 2-D design matrix, one row per example.
        labels: class labels, one per row of ``pi``. The sign test below
            assumes they are non-zero signed values such as +1/-1.
        gx_vector: weight vector, one entry per column of ``pi``.

    Returns:
        The classification error rate as a float in [0, 1].

    Raises:
        ValueError: if ``labels`` is empty, since no rate can be computed.
    """
    if len(labels) == 0:
        raise ValueError("cannot compute an error rate over zero examples")

    predictions = np.sign(np.dot(pi, gx_vector))

    # `<= 0` rather than `< 0`: np.sign yields 0 for a zero score, and such an
    # undecided prediction matches neither class, so it counts as an error.
    # np.mean keeps the reduction vectorised and always divides as a float.
    return np.mean((labels * predictions) <= 0)
In [36]:
# Load the whitespace-separated training and test sets as plain NumPy arrays.
# `sep=r"\s+"` is the documented equivalent of the deprecated
# `delim_whitespace=True`, and `.to_numpy()` replaces `DataFrame.as_matrix()`,
# which was removed in pandas 1.0.
train = pd.read_csv("./data/in.dta", sep=r"\s+", header=None).to_numpy()
test = pd.read_csv("./data/out.dta", sep=r"\s+", header=None).to_numpy()
In [86]:
# Split each data matrix: the first two columns are the inputs, and the third
# column is what the later cells pass as labels.
train_x, train_y = train[:, :2], train[:, 2]
test_x, test_y = test[:, :2], test[:, 2]
In [259]:
# Apply the same nonlinear feature map to the training and test inputs.
train_z, test_z = (transform(split) for split in (train_x, test_x))
In [285]:
# Fit plain least squares (no weight decay) on the transformed training set,
# then score the same weights in sample and out of sample.
gx_vector = solve_linear_regression(train_z, train_y)
error_in, error_out = (
    get_linear_regression_error(features, targets, gx_vector)
    for features, targets in ((train_z, train_y), (test_z, test_y))
)
print("error_in = %.3f\t error_out = %.3f" % (error_in, error_out))
In [286]:
# Inspect the dimensions of the transformed training matrix
# (transform() produces 8 feature columns per example).
train_z.shape
Out[286]:
In [296]:
def solve_linear_regression_with_decay(pi, labels, lmd):
    """Return least-squares weights with an added ``lmd * I`` penalty term.

    Computes pinv(pi^T pi + lmd * I) pi^T labels. The penalty is applied to
    every column of ``pi`` alike.

    Args:
        pi: 2-D design matrix, one row per example.
        labels: target values, one per row of ``pi``.
        lmd: weight-decay coefficient multiplying the identity matrix.

    Returns:
        The weight vector, one entry per column of ``pi``.
    """
    n_features = pi.shape[1]
    penalised_gram = np.dot(pi.T, pi) + lmd * np.eye(n_features)
    projector = np.dot(np.linalg.pinv(penalised_gram), pi.T)
    return np.dot(projector, labels)
In [314]:
# Sweep the weight-decay coefficient over 10^k for k = -3 .. 3 and report the
# in-sample and out-of-sample error of each fit.
for k in range(-3, 4):
    gx_vector = solve_linear_regression_with_decay(train_z, train_y, 10 ** k)
    error_in = get_linear_regression_error(train_z, train_y, gx_vector)
    error_out = get_linear_regression_error(test_z, test_y, gx_vector)
    print("k = %d\t error_in = %.3f\t error_out = %.3f" % (k, error_in, error_out))